# -*- coding: utf-8 -*-
# LSTM (RNN) classifier for the MNIST dataset.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn

mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
testimgs, testlabels = mnist.test.images, mnist.test.labels

# Model dimensions: each image is fed as `nsteps` rows of `diminput` values,
# projected to `dimhidden` features, and classified into `n_classes` classes.
n_classes = 10
diminput = 28
dimhidden = 128
dimoutput = n_classes
nsteps = 28

# Trainable parameters for the input projection ('hidden') and the output
# layer ('out'), initialized from tf.random_normal's default N(0, 1).
weights = dict(
    hidden=tf.Variable(tf.random_normal([diminput, dimhidden])),
    out=tf.Variable(tf.random_normal([dimhidden, dimoutput])),
)
biases = dict(
    hidden=tf.Variable(tf.random_normal([dimhidden])),
    out=tf.Variable(tf.random_normal([dimoutput])),
)

def _RNN(_X, _W, _b, _nsteps, _name):
    """Build a single-layer LSTM classifier graph and return its tensors.

    Each time step's input is passed through the same linear projection,
    the projected sequence is run through an LSTM with ``rnn.static_rnn``,
    and the output of the last time step is mapped to class logits.

    Args:
        _X: input tensor shaped [batchsize, nsteps, diminput].
        _W: dict of weights; 'hidden' is the [diminput, dimhidden] input
            projection and 'out' is the [lstm_units, dimoutput] output layer.
        _b: dict of biases with the same keys as ``_W``.
        _nsteps: number of time steps to split the projected input into.
        _name: variable-scope name under which the LSTM variables are created.

    Returns:
        dict of graph tensors: 'X' (input flattened to
        [nsteps*batchsize, diminput]), 'H' (projected input), 'Hsplit'
        (list of per-step inputs), 'LSTM_O' (list of per-step LSTM outputs),
        'LSTM_S' (final LSTM state) and 'O' (logits from the last step).
    """
    # Take sizes from the weight variables instead of the module-level
    # diminput/dimhidden, so the function depends only on its arguments.
    # Assumes both weight matrices have fully known static shapes.
    n_input = _W['hidden'].get_shape().as_list()[0]
    # The LSTM output feeds _W['out'], so its unit count must equal that
    # matrix's row count.
    n_units = _W['out'].get_shape().as_list()[0]
    # 1. Permute input from [batchsize, nsteps, diminput]
    #    => [nsteps, batchsize, diminput] so time is the leading axis.
    _X = tf.transpose(_X, [1, 0, 2])
    # 2. Reshape input to [nsteps*batchsize, diminput].
    _X = tf.reshape(_X, [-1, n_input])
    # 3. Input layer => hidden layer; one projection shared by every step.
    _H = tf.matmul(_X, _W['hidden']) + _b['hidden']
    # 4. Split into `_nsteps` chunks along axis 0. Rows are ordered
    #    step-major after step 1, so the i-th chunk holds time step i for
    #    every batch element.
    _Hsplit = tf.split(_H, _nsteps, 0)
    # 5. Run the LSTM over the per-step inputs. _LSTM_O has one output per
    #    time step and _LSTM_S is the final state; only the last output is
    #    used for the prediction.
    with tf.variable_scope(_name):
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_units, forget_bias=1.0)
        _LSTM_O, _LSTM_S = rnn.static_rnn(lstm_cell, _Hsplit, dtype=tf.float32)
    # 6. Output layer: logits computed from the last time step's output.
    _O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
    return {
        'X': _X, 'H': _H, 'Hsplit': _Hsplit,
        'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
    }
print('Network ready')

learning_rate = 0.001
# Inputs: a batch of sequences [batch, nsteps, diminput] and one-hot labels.
x = tf.placeholder("float", [None, nsteps, diminput])
y = tf.placeholder("float", [None, dimoutput])
# NOTE(review): keepratio is fed during training, but no op in this graph
# reads it (there is no dropout layer), so it currently has no effect.
keepratio = tf.placeholder(tf.float32)
myrnn = _RNN(x, weights, biases, nsteps, 'basic')
pred = myrnn['O']
# Labels are one-hot; argmax turns them into class indices for the sparse loss.
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y, 1), logits=pred))
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Use tf.argmax (as in the loss above) rather than the deprecated tf.arg_max.
_corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(_corr, tf.float32))
init = tf.global_variables_initializer()
print("Network Ready")

# A single session is created and initialized for the whole run.
sess = tf.Session()
sess.run(init)

# Number of training epochs.
training_epochs = 5
# Number of samples per training batch.
batch_size = 16
ntest = mnist.test.images.shape[0]
# Report progress every `display_step` epochs.
display_step = 1
# Batches per epoch, capped at 100; a full pass over the training set would
# be int(mnist.train.num_examples / batch_size).
num_batch = 100
# The test set does not change, so reshape it to [ntest, nsteps, diminput]
# once instead of on every display step.
testimgs = testimgs.reshape((ntest, nsteps, diminput))
for epoch in range(training_epochs):
    avg_cost = 0
    for i in range(num_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
        # Fetch the loss in the same run as the update step. One forward
        # pass serves both, and the reported loss is the batch loss the
        # gradients were computed from.
        _, batch_cost = sess.run(
            [optm, cost], feed_dict={x: batch_xs, y: batch_ys, keepratio: 0.7})
        avg_cost += batch_cost / num_batch
    # DISPLAY
    if epoch % display_step == 0:
        print("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
        # Accuracy on the last training batch of this epoch only.
        train_acc = sess.run(accr, feed_dict={x: batch_xs, y: batch_ys})
        print("Training accuracy: %.3f" % (train_acc))
        test_acc = sess.run(accr, feed_dict={x: testimgs, y: testlabels})
        print("Test accuracy: %.3f" % test_acc)
print('Done!')

